import PIL
from PIL import Image
import gradio as gr
from pymilvus import Collection

from app.app_init import device, model_L14, preprocess_L14, model_H14, preprocess_H14
from app.gradio.constants import description
from app.utll.embedding_util import get_embedding


def _escape_milvus_string(value):
    """Return *value* escaped for use inside a single-quoted Milvus filter literal.

    Backslashes and single quotes are backslash-escaped so that user-typed text
    (e.g. the ``user_id`` textbox) cannot close the literal and append extra
    clauses such as ``1' or user_id != '``.
    """
    return str(value).replace("\\", "\\\\").replace("'", "\\'")


def _resolve_model(model_name):
    """Map a UI model name to ``(collection_name, model, preprocess)``.

    Raises:
        ValueError: if *model_name* is not one of the supported model names.
    """
    if model_name == "ViT-L-14":
        return "multimodal_search_L14", model_L14, preprocess_L14
    if model_name == "ViT-H-14":
        return "multimodal_search_H14", model_H14, preprocess_H14
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'ViT-L-14' or 'ViT-H-14'"
    )


def image2image_gr():
    """Build and return the gradio Blocks UI for image-to-image search."""

    def clip_api(query_image=None, return_num=8, top_k_num=8, thresh_hold=0.22, user_id="1", model_name="ViT-L-14"):
        """Find indexed images of *user_id* that are similar to *query_image*.

        Args:
            query_image: the query image; anything that is not a PIL image
                (including ``None``) yields ``None``.
            return_num: number of candidates requested from Milvus (``limit``);
                cast to int and clamped to at least 1.
            top_k_num: maximum number of paths returned after score filtering.
            thresh_hold: minimum COSINE score a hit must reach to be kept.
            user_id: only rows whose ``user_id`` field equals this are searched.
            model_name: "ViT-L-14" or "ViT-H-14"; selects model and collection.

        Returns:
            A list of ``image_path`` strings, or ``None`` when no image is given.

        Raises:
            ValueError: if *model_name* is not supported.
        """
        if query_image is None or not isinstance(query_image, Image.Image):
            return None

        collection_name, model, preprocess = _resolve_model(model_name)
        print("当前使用的模型为：", model_name)

        milvus_conn = Collection(collection_name)
        model_infos = {"model": model, "preprocess": preprocess, "device": device}

        # 1. Embed the query image (the PIL image is passed straight through).
        query_emb = get_embedding("image", query_image, model_infos, batch_mode=False)

        # Sliders may deliver floats and allow 0; Milvus requires an int limit >= 1,
        # and list slicing requires an int.
        limit = max(1, int(return_num))
        top_k = max(0, int(top_k_num))

        # 2. Vector similarity search restricted to this user's rows.
        search_param = {
            "data": query_emb,
            "anns_field": "vector",
            "param": {"metric_type": "COSINE", "params": {"nprobe": 10}, "offset": 0},
            "limit": limit,
            # user_id comes from a free-text box: escape it so it stays a literal.
            "expr": f"(user_id=='{_escape_milvus_string(user_id)}')",
            "output_fields": ["id", "user_id", "image_path"],
        }

        # Log the request without the (large) embedding vector.
        loggable_param = {key: val for key, val in search_param.items() if key != "data"}
        print(f"search_param: {loggable_param}")
        search_result = milvus_conn.search(**search_param)

        result = []
        for hit in search_result[0]:
            row = hit.entity.to_dict()
            fields = row["entity"]
            result.append({
                "id": fields["id"],
                "user_id": fields["user_id"],
                "image_path": fields["image_path"],
                "score": row["distance"],
            })

        # Keep hits at or above the threshold, then the first top_k of them
        # (assumes Milvus returns hits best-first for COSINE — confirm).
        filtered_result = [x for x in result if x["score"] >= thresh_hold][:top_k]

        print("filtered_result:", filtered_result)

        return [item["image_path"] for item in filtered_result]

    examples = [
        ["./app/gradio/examples/cat.jpeg", 20, 16, 0.24, "1", "ViT-H-14"],
        ["./app/gradio/examples/coffee.jpeg", 20, 16, 0.24, "1", "ViT-H-14"],
        ["./app/gradio/examples/people.jpeg", 20, 16, 0.24, "1", "ViT-H-14"],
    ]

    title = "<h1 align='center'>基于AI大模型和向量数据库的图片搜索引擎</h1>"

    with gr.Blocks() as image_block:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Column(scale=2):
                    img = gr.components.Image(label="图片", type="pil", elem_id=1)
                num = gr.components.Slider(minimum=0, maximum=200, step=1, value=50, label="返回图片数",
                                           elem_id=2)
                top_k = gr.components.Slider(minimum=0, maximum=200, step=1, value=20,
                                             label="top_k：召回的图片里面，相似分大于设定阈值的，前K张图片",
                                             elem_id=3)
                score_thresh_hold = gr.components.Slider(minimum=-1, maximum=1, step=0.01, value=0.20,
                                                         label="召回分数阈值")
                user_id_text = gr.Textbox(value="1", label="请填写user_id", interactive=True)
                # gr.inputs.* is deprecated; gr.Radio(value=...) is the supported form.
                model_name_radio = gr.Radio(["ViT-L-14", "ViT-H-14"], value="ViT-H-14", label="模型选择")
                btn = gr.Button("搜索")
            with gr.Column(scale=100):
                out = gr.Gallery(label="检索结果为：").style(grid=4, height=700)
        # Order must match clip_api's positional parameters.
        inputs = [img, num, top_k, score_thresh_hold, user_id_text, model_name_radio]
        btn.click(fn=clip_api, inputs=inputs, outputs=out)
        gr.Examples(examples, inputs=inputs)
    return image_block


if __name__ == "__main__":
    # Build each tab's Blocks and its title, then serve them as one tabbed app.
    tab_blocks = [image2image_gr()]
    tab_titles = ["图到图搜索"]
    with gr.TabbedInterface(tab_blocks, tab_titles) as demo:
        demo.launch(enable_queue=True)
